"""@plugin LuckyThirteenVulnerabilityTesterPlugin
This file contains the classes of a plugin that tests server(s) for the vulnerability CVE-2013-0169 (Lucky13).

@author: Bc. Pavel Soukup
"""

from operator import attrgetter
from xml.etree.ElementTree import Element

from nassl import _nassl
from nassl.ssl_client import SslClient
from sets import Set

from OpenSSL import SSL
from OpenSSL._util import (lib as _lib)
import socket

from sslyze.plugins.common_new_plugin_info import OPENSSL_TO_RFC_NAMES_MAPPING, PROTOCOL_VERSION, AcceptCipher, RejectCipher, TLS_OPENSSL_TO_RFC_NAMES_MAPPING
from sslyze.plugins import plugin_base
from sslyze.plugins.plugin_base import PluginResult
from sslyze.utils.ssl_connection import SSLHandshakeRejected, SSLConnection
from sslyze.utils.thread_pool import ThreadPool

class LuckyThirteenVulnerabilityTesterPlugin(plugin_base.PluginBase):
    """@class LuckyThirteenVulnerabilityTesterPlugin

    Plugin (subclass of PluginBase) that tests a server for vulnerability
    CVE-2013-0169 (Lucky13).

    Every cipher suite the client can offer is tried against the server for
    each supported protocol version (TLS 1.1/1.2 or DTLS 1.0/1.2). The server
    is reported as vulnerable when it accepts at least one cipher suite for
    which ``use_CBC_mode()`` is true.
    """
    interface = plugin_base.PluginInterface(
        "LuckyThirteenVulnerabilityTesterPlugin",
        "Scans the server(s) and checks if requirements for Lucky13 attack are satisfied.")
    interface.add_command(
        command="lucky13_tls",
        help="Tests server(s) for CVE-2013-0169 vulnerability. It uses only protocol TLSv1.1 and TLSv1.2.")
    interface.add_command(
        command="lucky13_dtls",
        help="Tests server(s) for CVE-2013-0169 vulnerability. It uses only protocol DTLSv1 and DTLSv1.2.")
    interface.add_option(
        option="port",
        help="Specify which port you want connect to.",
        dest="port")

    # Upper bound for the number of worker threads used by the TLS scan.
    MAX_THREADS = 20

    # Keys under which the DTLS client methods get registered in
    # SSL.Context._methods (see get_support_dtls_protocols_by_client).
    DTLSv1_METHOD = 7
    DTLSv1_2_METHOD = 8

    # Maps each key above to the name of the function looked up on the
    # OpenSSL binding (_lib).
    DTLS_MODULE = {
        DTLSv1_METHOD: "DTLSv1_client_method",
        DTLSv1_2_METHOD: "DTLSv1_2_client_method"}

    def process_task(self, server_connectivity_info, plugin_command, option_dict=None):
        """Runs the scan selected by plugin_command and returns its result.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            plugin_command (str): 'lucky13_tls' or 'lucky13_dtls'
            option_dict (dict): optional options; 'verbose' enables per-cipher output, 'port' is required for 'lucky13_dtls'.

            Returns:
            LuckyThirteenVulnerabilityTesterResult

            Raises:
            ValueError: unknown command, missing --port for the DTLS command or an unexpected job result.
            Exception: the first exception reported by a worker thread is re-raised.
        """
        options = option_dict or {}
        verbose_mode = options.get('verbose', False)

        if plugin_command == 'lucky13_tls':
            (thread_pool, threads) = self.create_thread_pool_for_protocol_tls(server_connectivity_info)
            adv_info = {'protocol': 'TLS'}
        elif plugin_command == 'lucky13_dtls':
            port = options.get('port')
            # A key that is present but unset (None) is treated as a missing option
            # instead of failing later in int(None).
            if port is None:
                raise ValueError("LuckyThirteenVulnerabilityTesterPlugin: Missing option --port for command --lucky13-dtls")
            (thread_pool, threads) = self.create_thread_pool_for_protocol_dtls(server_connectivity_info, int(port))
            adv_info = {'protocol': 'DTLS', 'port': port}
        else:
            raise ValueError("LuckyThirteenVulnerabilityTesterPlugin: Unknown command")
        thread_pool.start(nb_threads=threads)

        accept_ciphers = []
        reject_ciphers = []

        # print(...) with a single string argument produces the same output as the
        # Python 2 print statement.
        if verbose_mode:
            print('  VERBOSE MODE PRINT')
            print('  ------------------')
        for (_job, cipher_result) in thread_pool.get_result():
            if isinstance(cipher_result, AcceptCipher):
                accept_ciphers.append(cipher_result)
            elif isinstance(cipher_result, RejectCipher):
                reject_ciphers.append(cipher_result)
            else:
                raise ValueError("Unexpected result")
            if verbose_mode:
                cipher_result.print_cipher()

        if verbose_mode:
            print('  ----------------------')
            print('  END VERBOSE MODE PRINT')
            print('  ----------------------')

        # Consume every reported error and join the pool BEFORE re-raising, so the
        # worker threads are always joined even when a job failed.
        first_error = None
        for (_, exception) in thread_pool.get_error():
            if first_error is None:
                first_error = exception

        thread_pool.join()

        if first_error is not None:
            raise first_error

        support_vulnerable_ciphers = self.get_vulnerable_cipher_set(accept_ciphers)
        is_vulnerable = len(support_vulnerable_ciphers) > 0

        return LuckyThirteenVulnerabilityTesterResult(server_connectivity_info, plugin_command, option_dict,
                                                      is_vulnerable, support_vulnerable_ciphers, adv_info)

    def create_thread_pool_for_protocol_tls(self, server_connectivity_info):
        """Creates a ThreadPool with one job per cipher suite for every supported protocol out of TLS 1.1 and TLS 1.2.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server.

            Returns:
            tuple (ThreadPool, int): the filled pool and the number of threads to start it with.
        """
        thread_pool = ThreadPool()
        job_count = 0
        for protocol in ('TLSv1.1', 'TLSv1.2'):
            if not self.test_protocol_support(server_connectivity_info, protocol):
                continue
            for cipher in self.get_cipher_list(protocol):
                thread_pool.add_job((self._test_ciphersuite, (server_connectivity_info, protocol, cipher)))
                job_count += 1

        # Size the pool by the number of queued jobs. Keep at least one thread, like
        # the DTLS variant does, so an empty pool is started the same way in both.
        return (thread_pool, max(1, min(self.MAX_THREADS, job_count)))

    def create_thread_pool_for_protocol_dtls(self, server_connectivity_info, port):
        """Creates a ThreadPool with one job per cipher suite for every supported protocol out of DTLS 1.0 and DTLS 1.2.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            port (int): port number used for the DTLS connection.

            Returns:
            tuple (ThreadPool, int): the filled pool and the number of threads to start it with (always 1).
        """
        thread_pool = ThreadPool()
        for protocol in self.get_support_dtls_protocols_by_client():
            if not self.test_dtls_protocol_support(server_connectivity_info, protocol, port):
                continue
            for cipher in self.get_dtls_cipher_list(protocol):
                thread_pool.add_job((self._test_dtls_ciphersuite, (server_connectivity_info, protocol, cipher, port)))

        # NOTE(review): the DTLS jobs run on a single thread; the reason is not
        # stated in this file, so the value is kept as it was.
        return (thread_pool, 1)

    def _test_ciphersuite(self, server_connectivity_info, ssl_version, cipher):
        """Thread job: checks if the server accepts one cipher suite over TLS. Returns instance of class AcceptCipher or RejectCipher.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            ssl_version (str): contains SSL/TLS protocol version, which is used to connect
            cipher (str): contains OpenSSL shortcut for identification cipher suite.
        """
        protocol_id = PROTOCOL_VERSION[ssl_version]
        ssl_conn = server_connectivity_info.get_preconfigured_ssl_connection(override_ssl_version=protocol_id)
        ssl_conn.set_cipher_list(cipher)
        # Fall back to the OpenSSL name when no RFC name is known for the cipher.
        rfc_name = OPENSSL_TO_RFC_NAMES_MAPPING[protocol_id].get(cipher, cipher)
        try:
            ssl_conn.connect()
        except SSLHandshakeRejected as e:
            cipher_result = RejectCipher(rfc_name, str(e))
        else:
            cipher_result = AcceptCipher(rfc_name)
        finally:
            ssl_conn.close()
        return cipher_result

    def _test_dtls_ciphersuite(self, server_connectivity_info, dtls_version, cipher, port):
        """Thread job: checks if the server accepts one cipher suite over DTLS. Returns instance of class AcceptCipher or RejectCipher.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            dtls_version (int): key of the DTLS method in SSL.Context._methods, which is used to connect
            cipher (str): contains OpenSSL shortcut for identification cipher suite
            port (int): port number used for the DTLS connection.
        """
        cnx = SSL.Context(dtls_version)
        cnx.set_cipher_list(cipher)
        conn = SSL.Connection(cnx, socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
        # Same fallback as the TLS variant: an unmapped cipher is reported under its
        # OpenSSL name instead of failing with KeyError.
        rfc_name = TLS_OPENSSL_TO_RFC_NAMES_MAPPING.get(cipher, cipher)
        try:
            conn.connect((server_connectivity_info.ip_address, port))
            conn.do_handshake()
        except SSL.Error as e:
            cipher_result = RejectCipher(rfc_name, self._extract_ssl_error_reason(e))
        else:
            cipher_result = AcceptCipher(rfc_name)
        finally:
            self._close_dtls_connection(conn)
        return cipher_result

    @staticmethod
    def _extract_ssl_error_reason(error):
        """Returns the reason text of an SSL.Error, or str(error) when the arguments have a different layout."""
        try:
            # Expected layout: args[0] is a list of tuples whose third item is the reason.
            return error.args[0][0][2]
        except (IndexError, TypeError):
            # E.g. an error without arguments, or one whose first argument is an errno.
            return str(error)

    @staticmethod
    def _close_dtls_connection(conn):
        """Shuts down and closes a DTLS connection; the close happens even when the shutdown fails."""
        try:
            conn.shutdown()
        except SSL.Error:
            # Deliberately ignored: a failing shutdown (e.g. after a handshake that
            # never completed) must neither replace the test result nor skip close().
            pass
        finally:
            conn.close()

    def test_dtls_protocol_support(self, server_connectivity_info, dtls_version, port):
        """Tests if a DTLS protocol version is supported by server. Returns True if server supports the protocol, otherwise returns False.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            dtls_version (int): key of the DTLS method in SSL.Context._methods, which is supposed to be tested
            port (int): port number used for the DTLS connection.

            Raises:
            ValueError: the connection was refused, i.e. nothing listens on the given port.
        """
        # Function-scope import keeps this block self-contained.
        import errno

        cnx = SSL.Context(dtls_version)
        cnx.set_cipher_list('ALL:COMPLEMENTOFALL')
        conn = SSL.Connection(cnx, socket.socket(socket.AF_INET, socket.SOCK_DGRAM))
        try:
            conn.connect((server_connectivity_info.ip_address, port))
            conn.do_handshake()
        except SSL.SysCallError as ex:
            # errno.ECONNREFUSED is 111 on Linux; the named constant is also right elsewhere.
            if ex.args and ex.args[0] == errno.ECONNREFUSED:
                raise ValueError('LuckyThirteenVulnerabilityTesterPlugin: It is entered wrong port for DTLS connection.')
            support = False
        except SSL.Error:
            # Any other handshake failure means this protocol version is not usable.
            support = False
        else:
            support = True
        finally:
            self._close_dtls_connection(conn)
        return support

    def get_support_dtls_protocols_by_client(self):
        """Returns list of all DTLS method keys (see DTLS_MODULE) which are supported by client.

        Each available method is registered in SSL.Context._methods as a side
        effect, so SSL.Context(key) can be used afterwards.
        """
        dtls_ary = []
        for dtls_version in self.DTLS_MODULE:
            try:
                SSL.Context._methods[dtls_version] = getattr(_lib, self.DTLS_MODULE[dtls_version])
            except AttributeError:
                # The binding does not provide this method: the version is unsupported.
                continue
            dtls_ary.append(dtls_version)
        return dtls_ary

    def test_protocol_support(self, server_connectivity_info, ssl_protocol):
        """Tests if an SSL/TLS protocol version is supported by server. Returns True if server supports the protocol, otherwise returns False.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            ssl_protocol (str): contains version of SSL/TLS protocol, which is supposed to be tested.
        """
        protocol_id = PROTOCOL_VERSION[ssl_protocol]
        # No connection attempt is needed above the highest version the server supports.
        if server_connectivity_info.highest_ssl_version_supported < protocol_id:
            return False
        ssl_conn = server_connectivity_info.get_preconfigured_ssl_connection(override_ssl_version=protocol_id)
        try:
            ssl_conn.connect()
        except SSLHandshakeRejected:
            return False
        finally:
            ssl_conn.close()
        return True

    def get_dtls_cipher_list(self, dtls_version):
        """Returns list of cipher suites which the client offers for the given DTLS method.

            Args:
            dtls_version (int): key of the DTLS method in SSL.Context._methods.
        """
        cnx = SSL.Context(dtls_version)
        cnx.set_cipher_list('ALL:COMPLEMENTOFALL')
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            return SSL.Connection(cnx, sock).get_cipher_list()
        finally:
            # The socket only exists to build the Connection object; release it.
            sock.close()

    def get_cipher_list(self, ssl_protocol):
        """Returns list of cipher suites which the client offers for the given SSL/TLS protocol version.

            Args:
            ssl_protocol (str): contains version of SSL/TLS protocol.
        """
        ssl_client = SslClient(ssl_version=PROTOCOL_VERSION[ssl_protocol])
        ssl_client.set_cipher_list('ALL:COMPLEMENTOFALL')
        return ssl_client.get_cipher_list()

    def get_vulnerable_cipher_set(self, accept_ciphers):
        """Returns set of RFC names of the accepted cipher suites which use CBC mode.

            Args:
            accept_ciphers (list): contains cipher suites, which are supported by server.
        """
        return Set(cipher._cipher_rfc_name for cipher in accept_ciphers if cipher.use_CBC_mode())

class LuckyThirteenVulnerabilityTesterResult(PluginResult):
    """@class LuckyThirteenVulnerabilityTesterResult

    Subclass of PluginResult which carries the outcome of the test made by
    class LuckyThirteenVulnerabilityTesterPlugin and renders it as text or XML.
    """
    COMMAND_TITLE = 'Vulnerability CVE-2013-0169'
    CIPHER_LIST_TITLE_FORMAT = '      {section_title:<32}'.format
    CIPHER_LINE_FORMAT = u'        {cipher_name:<50}'.format

    def __init__(self, server_connectivity_info, plugin_command, plugin_option, is_vulnerable, support_vulnerable_ciphers, adv_info):
        """Stores the scan outcome.

            Args:
            server_connectivity_info (ServerConnectivityInfo): contains information for connection on server
            plugin_command (str): command which produced this result
            plugin_option (dict): options the command was run with
            is_vulnerable (bool): True when at least one vulnerable cipher suite was accepted
            support_vulnerable_ciphers (set): names of the accepted vulnerable cipher suites
            adv_info (dict): contains key 'protocol' ('TLS' or 'DTLS') and, for DTLS, key 'port'.
        """
        super(LuckyThirteenVulnerabilityTesterResult, self).__init__(server_connectivity_info, plugin_command, plugin_option)
        self.is_vulnerable = is_vulnerable
        self.support_vulnerable_ciphers = support_vulnerable_ciphers
        self.adv_info = adv_info
        self.ip_address = server_connectivity_info.ip_address

    def as_text(self):
        """Returns the result as a list of text lines."""
        txt_output = []
        if self.adv_info['protocol'] == 'DTLS':
            # The DTLS scan targets its own port, so the output names the endpoint.
            header = 'SCAN IS DONE FOR DTLS PROTOCOLS - {ip_address}:{port}'.format(
                ip_address=self.ip_address, port=self.adv_info['port'])
            txt_output.append(header)
            txt_output.append('-' * len(header))

        if self.is_vulnerable:
            lucky_txt = 'VULNERABLE - server is vulnerable to Lucky13 attack'
        else:
            lucky_txt = 'OK - Not vulnerable to Lucky13 attack'

        txt_output.append(self.PLUGIN_TITLE_FORMAT(self.COMMAND_TITLE))
        txt_output.append(self.FIELD_FORMAT("", lucky_txt))
        if self.is_vulnerable:
            txt_output.append(self.CIPHER_LIST_TITLE_FORMAT(section_title='Vulnerable cipher/ciphers:'))
        # Sorted, because the iteration order of a set is arbitrary and the report
        # should look the same on every run.
        for cipher_name in sorted(self.support_vulnerable_ciphers):
            txt_output.append(self.CIPHER_LINE_FORMAT(cipher_name=cipher_name))
        return txt_output

    def as_xml(self):
        """Returns the result as an xml.etree.ElementTree.Element."""
        xml_output = Element(self.plugin_command, title=self.COMMAND_TITLE)
        if self.adv_info['protocol'] == 'DTLS':
            xml_output.append(Element('diffConnectionInfo', attrib={
                'protocol': self.adv_info['protocol'],
                'ip_address': self.ip_address,
                # Attribute values have to be strings to be serializable.
                'port': str(self.adv_info['port'])}))

        xml_output.append(Element('vulnerable', isVulnerable=str(self.is_vulnerable)))

        if len(self.support_vulnerable_ciphers) > 0:
            ciphers_xml = Element('vulnerableCipherSuites')
            # Sorted for the same reason as in as_text: stable output.
            for cipher_name in sorted(self.support_vulnerable_ciphers):
                ciphers_xml.append(Element('cipherSuite', name=cipher_name))
            xml_output.append(ciphers_xml)
        return xml_output